{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook is a simple example of how to use the `AnnubesEnv` class to create a custom environment and use it to train a reinforcement learning agent with `stable_baselines3`.\n",
    "\n",
    "## `AnnubesEnv` environment\n",
    "\n",
    "Let's create an environment, check it works and visualize it.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import neurogym as ngym\n",
    "from neurogym.envs.annubes import AnnubesEnv\n",
    "from stable_baselines3.common.env_checker import check_env\n",
    "\n",
    "env = AnnubesEnv()\n",
    "\n",
    "# check the custom environment and output additional warnings (if any)\n",
    "check_env(env)\n",
    "\n",
    "# check the environment with a random agent\n",
    "obs, info = env.reset()\n",
    "n_steps = 10\n",
    "for _ in range(n_steps):\n",
    "    # random action\n",
    "    action = env.action_space.sample()\n",
    "    obs, reward, terminated, truncated, info = env.step(action)\n",
    "    if terminated:\n",
    "        obs, info = env.reset()\n",
    "\n",
    "print(env.timing)\n",
    "print(\"----------------\")\n",
    "print(env.observation_space)\n",
    "print(env.observation_space.name)\n",
    "print(\"----------------\")\n",
    "print(env.action_space)\n",
    "print(env.action_space.name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = ngym.utils.plot_env(\n",
    "    env,\n",
    "    ob_traces=[\"fixation\", \"start\", \"v\", \"a\"],\n",
    "    num_trials=10\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training `AnnubesEnv`\n",
    "\n",
    "### 1. Regular training\n",
    "\n",
    "We can train `AnnubesEnv` using one of the models defined in `stable_baselines3`, for example [`A2C`](https://stable-baselines3.readthedocs.io/en/master/modules/a2c.html).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "from stable_baselines3 import A2C\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv\n",
    "\n",
    "warnings.filterwarnings(\"default\")\n",
    "\n",
    "# train agent\n",
    "env = AnnubesEnv()\n",
    "env_vec = DummyVecEnv([lambda: env])\n",
    "model = A2C(\"MlpPolicy\", env_vec, verbose=0)\n",
    "model.learn(total_timesteps=20000, log_interval=1000)\n",
    "env_vec.close()\n",
    "\n",
    "# plot example trials with trained agent\n",
    "data = ngym.utils.plot_env(env, num_trials=10, ob_traces=[\"fixation\", \"start\", \"v\", \"a\"], model=model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Sequential training\n",
    "\n",
    "We can also train `AnnubesEnv` using a sequential training approach. This is useful when we want to train the agent in multiple stages, each with a different environment configuration. This can be useful for:\n",
    "\n",
    "- **Curriculum learning**: Gradually increase the difficulty of the environments. Start with simpler tasks and progressively move to more complex ones, allowing the agent to build on its previous experiences.\n",
    "\n",
    "- **Domain randomization**: Vary the environment dynamics (e.g., physics, obstacles) during training to improve the agent's robustness to changes in the environment.\n",
    "\n",
    "- **Transfer learning**: If you have access to different agents or architectures, you can use transfer learning techniques to fine-tune the model on a new environment.\n",
    "\n",
    "In this case it is important to include all the possible observations in each environment, even if not all of them are used. This is because the model is initialized with the first environment's observation space and it is not possible to change it later.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env1 = AnnubesEnv({\"v\": 1, \"a\": 0})\n",
    "env1_vec = DummyVecEnv([lambda: env1])\n",
    "# create a model and train it with the first environment\n",
    "model = A2C(\"MlpPolicy\", env1_vec, verbose=0)\n",
    "model.learn(total_timesteps=10000)\n",
    "env1_vec.close()\n",
    "\n",
    "# plot example trials with trained agent\n",
    "data = ngym.utils.plot_env(env1, num_trials=10, ob_traces=[\"fixation\", \"start\", \"v\", \"a\"], model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# switch to the second environment and continue training\n",
    "env2 = AnnubesEnv({\"v\": 0, \"a\": 1})\n",
    "env2_vec = DummyVecEnv([lambda: env2])\n",
    "# set the model's environment to the new environment\n",
    "model.set_env(env2_vec)\n",
    "model.learn(total_timesteps=10000)\n",
    "env2_vec.close()\n",
    "\n",
    "# plot example trials with trained agent\n",
    "data = ngym.utils.plot_env(env2, num_trials=10, ob_traces=[\"fixation\", \"start\", \"v\", \"a\"], model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch to the third environment and finish training\n",
    "env3 = AnnubesEnv({\"v\": 0.5, \"a\": 0.5})\n",
    "env3_vec = DummyVecEnv([lambda: env3])\n",
    "# set the model's environment to the new environment\n",
    "model.set_env(env3_vec)\n",
    "model.learn(total_timesteps=20000)\n",
    "env3_vec.close()\n",
    "\n",
    "# plot example trials with trained agent\n",
    "data = ngym.utils.plot_env(env3, num_trials=10, ob_traces=[\"fixation\", \"start\", \"v\", \"a\"], model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the final model after all training\n",
    "model.save(\"final_model\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "neurogym",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
